"""Build .docx from paper.md + all output tables."""

import re
from pathlib import Path
from docx import Document
from docx.shared import Inches, Pt, Cm, RGBColor
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.table import WD_TABLE_ALIGNMENT

# Project root directory (absolute path); every path below is derived from it.
PROJECT = Path("/mnt/c/demographics_capital_flows/monetary")
# Markdown source of the paper, read by build_docx().
PAPER_MD = PROJECT / "paper" / "paper.md"
# Directory holding the markdown table files named in TABLE_FILES.
TABLES_DIR = PROJECT / "output" / "tables"
# Directory scanned by build_docx() for .png/.jpg/.jpeg/.pdf figures.
FIGURES_DIR = PROJECT / "output" / "figures"
# Destination of the generated Word document.
OUTPUT = PROJECT / "paper" / "monetary_paper.docx"

# (heading label, markdown filename) pairs, listed in the order they appear in
# the tables appendix. Filenames are resolved against TABLES_DIR; build_docx()
# skips any entry whose file does not exist.
TABLE_FILES = [
    ("Table 1: Rate Levels", "phase2_table1_rate_levels.md"),
    ("Table 2: Inflation Baseline", "phase2_table2_inflation.md"),
    ("Table 3: Phillips Curve Interactions", "phase3_table3_phillips_interactions.md"),
    ("Table 4: Phillips Curve by Tercile", "phase3_table4_phillips_terciles.md"),
    ("Table 5: Transmission — Growth", "phase4_table5_transmission_growth.md"),
    ("Table 5b: Transmission — Investment Channel", "phase4_table5b_transmission_investment.md"),
    ("Table 5c: Transmission — Consumption Channel", "phase4_table5c_transmission_consumption.md"),
    ("Table 6: Transmission — Inflation", "phase4_table6_transmission_inflation.md"),
    ("Table 7: CBI Interactions", "phase5_table7_cbi_interactions.md"),
    ("Table 8: IT Interactions", "phase5_table8_it_interactions.md"),
    ("Table 9: Combined Regimes", "phase5_table9_combined_regimes.md"),
    ("Table 10: Global vs Domestic Z", "phase6_table10_global_domestic.md"),
    ("Table 11: World Time Series", "phase6_table11_world_ts.md"),
    ("Table 12: Chow Test", "phase7_table12_chow.md"),
    ("Table 13: QE Interaction", "phase7_table13_qe.md"),
    ("Table 13b: QE + Japanification Controls", "phase7_table13b_japan_controls.md"),
    ("Table 13c: Japanification vs QE Horse Race", "phase7_table13c_japan_horserace.md"),
    ("Table 14: ZLB Interaction", "phase7_table14_zlb.md"),
    ("Table 15: Rolling Windows", "phase7_table15_rolling.md"),
    ("Table 16: Post-QE Regressions", "phase8_table16_post_qe.md"),
    ("Table 17: Out-of-Sample", "phase8_table17_oos.md"),
    ("Table 18: Re-emergence", "phase8_table18_reemergence.md"),
    ("Table 19: Country Predictions", "phase8_table19_country.md"),
    ("Table 20: Robustness — Rates", "phase9_table20_robustness_rates.md"),
    ("Table 21: Robustness — Inflation", "phase9_table21_robustness_inflation.md"),
]


def set_cell_text(cell, text, bold=False, size=Pt(9), font_name='Times New Roman'):
    """Overwrite a table cell with a single centered, formatted run.

    Args:
        cell: python-docx table cell to write into.
        text: String placed in the cell.
        bold: Whether the run is bold.
        size: Font size applied to the run.
        font_name: Font family applied to the run.
    """
    # Blank the cell first, then build the styled run on its first paragraph
    cell.text = ""
    paragraph = cell.paragraphs[0]
    paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER

    styled_run = paragraph.add_run(text)
    styled_run.bold = bold
    font = styled_run.font
    font.name = font_name
    font.size = size


def _split_md_row(line):
    """Split one markdown table row into stripped cells, keeping empty ones."""
    row = line.strip()
    # Only the optional outer pipes are decoration; an empty cell between two
    # pipes is a real (blank) cell and must keep its column position.
    if row.startswith('|'):
        row = row[1:]
    if row.endswith('|') and not row.endswith('\\|'):
        row = row[:-1]
    # Split on unescaped pipes, then turn "\|" back into a literal pipe
    return [cell.strip().replace('\\|', '|')
            for cell in re.split(r'(?<!\\)\|', row)]


def add_md_table(doc, md_text, title=None):
    """Parse a pipe-separated markdown table and append it to the document.

    Lines without a pipe and lines starting with '#' are ignored, as are
    delimiter rows (cells made only of '-', ':' and spaces). Nothing is added
    unless at least two rows (header + one data row) remain. The first row is
    rendered bold.

    Args:
        doc: python-docx document to append to.
        md_text: Markdown text containing the table rows.
        title: Optional bold caption paragraph placed above the table.
    """
    rows = []
    for raw_line in md_text.strip().split('\n'):
        line = raw_line.strip()
        if not line or '|' not in line or line.startswith('#'):
            continue
        cells = _split_md_row(line)
        # Delimiter rows such as |---|:---:| carry no content
        if all(set(c) <= set('-: ') for c in cells):
            continue
        rows.append(cells)

    if len(rows) < 2:
        return

    if title:
        p = doc.add_paragraph()
        run = p.add_run(title)
        run.bold = True
        run.font.size = Pt(11)
        run.font.name = 'Times New Roman'

    # Ragged rows are allowed: short rows simply leave trailing cells blank
    n_cols = max(len(r) for r in rows)
    table = doc.add_table(rows=len(rows), cols=n_cols)
    table.style = 'Light Shading'
    table.alignment = WD_TABLE_ALIGNMENT.CENTER

    for i, row_data in enumerate(rows):
        for j, cell_text in enumerate(row_data):
            set_cell_text(table.cell(i, j), cell_text, bold=(i == 0))

    # Blank paragraph as spacing after the table
    doc.add_paragraph()


def parse_md_file(filepath):
    """Split a markdown table file into table and note sections.

    Lines are classified in this order:
      * a line starting with '#' ends any table in progress and becomes the
        title attached to the tables that follow it;
      * a line containing '|' is accumulated as a table row;
      * a line whose stripped text starts with '*' is a note; it ends any
        table in progress so that sections come out in document order.
    All other lines are ignored.

    Args:
        filepath: Path-like object providing ``read_text``.

    Returns:
        A list of tuples, either ``('table', title_or_None, table_markdown)``
        or ``('note', note_text)``.
    """
    # Explicit encoding: the files contain non-ASCII text (e.g. em dashes)
    text = filepath.read_text(encoding='utf-8')
    sections = []
    current_title = None
    table_buf = []

    def flush_table():
        # Emit the rows gathered so far as one table section
        if table_buf:
            sections.append(('table', current_title, '\n'.join(table_buf)))
            table_buf.clear()

    for line in text.split('\n'):
        if line.startswith('#'):
            flush_table()
            current_title = re.sub(r'^#+\s*', '', line).strip()
        elif '|' in line:
            table_buf.append(line)
        elif line.strip().startswith('*'):
            # Close the table first; otherwise a note that follows a table
            # would be lost or emitted ahead of the table it annotates.
            flush_table()
            sections.append(('note', line.strip()))

    flush_table()
    return sections


# LaTeX command name -> plain-text replacement. Wrapper/accent commands map to
# '' so that only their argument survives; unknown commands are dropped too.
_MATH_SYMBOLS = {
    'text': '', 'hat': '', 'widehat': '', 'bar': '', 'overline': '',
    'log': 'log', 'exp': 'exp',
    'cdot': '\u00b7', 'cdots': '...', 'times': '\u00d7', 'sum': '\u03a3',
    'alpha': '\u03b1', 'beta': '\u03b2', 'gamma': '\u03b3',
    'delta': '\u03b4', 'Delta': '\u0394', 'varepsilon': '\u03b5',
    'pi': '\u03c0', 'phi': '\u03c6',
}
_LATEX_COMMAND_RE = re.compile(r'\\([a-zA-Z]+)')
_UNESCAPED_BRACE_RE = re.compile(r'(?<!\\)[{}]')


def sanitize_math(math_text):
    """Convert LaTeX math into a plain-text approximation.

    Known commands become Unicode symbols, every other ``\\command`` is
    removed, grouping braces are stripped (so ``x_{it}`` becomes ``x_it``),
    escaped braces become literal braces, and control characters other than
    newline and tab are dropped.

    Args:
        math_text: LaTeX source without the surrounding ``$`` delimiters.

    Returns:
        The sanitized string.
    """
    # Match whole command names in one pass so "\cdots" is never read as
    # "\cdot" + "s" and "\pi" cannot fire inside a longer command name.
    math_text = _LATEX_COMMAND_RE.sub(
        lambda m: _MATH_SYMBOLS.get(m.group(1), ''), math_text)
    # Braces are stripped only after the commands are gone; removing them
    # earlier leaves half-open groups such as "x_{it" behind.
    math_text = _UNESCAPED_BRACE_RE.sub('', math_text)
    math_text = math_text.replace('\\{', '{').replace('\\}', '}')
    # Filter control characters
    return ''.join(c for c in math_text if ord(c) >= 32 or c in '\n\t')


# Inline-math replacements, applied in order: the accented form must come
# before the bare '\beta' rule or it would never match.
_INLINE_MATH_SYMBOLS = (
    ('\\hat{\\beta}', '\u03b2\u0302'),
    ('\\beta', '\u03b2'), ('\\Delta', '\u0394'),
    ('\\times', '\u00d7'), ('\\gamma', '\u03b3'),
    ('\\pi', '\u03c0'), ('\\phi', '\u03c6'),
    ('\\bar', ''), ('\\alpha', '\u03b1'),
    ('\\varepsilon', '\u03b5'), ('\\delta', '\u03b4'),
)
# ***bold italic***, **bold**, *italic* or $math$; one capture group so that
# re.split returns literal text and markup tokens alternately.
_INLINE_MARKUP_RE = re.compile(
    r'(\*\*\*[^*]+\*\*\*|\*\*[^*]+\*\*|\*[^*]+\*|\$[^$]+\$)')


def _inline_math_to_text(math):
    """Approximate the LaTeX inside ``$...$`` as plain text."""
    for old, new in _INLINE_MATH_SYMBOLS:
        math = math.replace(old, new)
    math = re.sub(r'\\[a-zA-Z]+', '', math)
    # Strip grouping braces last so "_{it}" -> "_it" and "\bar{y}" -> "y";
    # escaped braces are kept as literal braces.
    math = re.sub(r'(?<!\\)[{}]', '', math)
    math = math.replace('\\{', '{').replace('\\}', '}')
    return ''.join(c for c in math if ord(c) >= 32 or c in '\n\t')


def add_formatted_paragraph(doc, text, style=None):
    """Append a paragraph, rendering **bold**, *italic*, ***both*** and $math$.

    Plain text and emphasis use 11pt Times New Roman; inline math is converted
    to plain text and set in 10pt italic Cambria Math.

    Args:
        doc: python-docx document to append to.
        text: One line of markdown-flavoured text.
        style: Optional paragraph style passed to ``doc.add_paragraph``.

    Returns:
        The paragraph that was added.
    """
    p = doc.add_paragraph(style=style)

    def add_run(content, bold=False, italic=False,
                font_name='Times New Roman', size=11):
        run = p.add_run(content)
        # Set emphasis only when requested so that unformatted text keeps
        # inheriting bold/italic from the paragraph style.
        if bold:
            run.bold = True
        if italic:
            run.italic = True
        run.font.name = font_name
        run.font.size = Pt(size)

    for idx, part in enumerate(_INLINE_MARKUP_RE.split(text)):
        if idx % 2 == 0:
            # Even indices are literal text. Classifying by position (not by
            # prefix) keeps a stray "**" or "$" from being mistaken for empty
            # markup and dropped; empty pieces produce no run at all.
            if part:
                add_run(part)
        elif part.startswith('***'):
            add_run(part[3:-3], bold=True, italic=True)
        elif part.startswith('**'):
            add_run(part[2:-2], bold=True)
        elif part.startswith('*'):
            add_run(part[1:-1], italic=True)
        else:
            add_run(_inline_math_to_text(part[1:-1]), italic=True,
                    font_name='Cambria Math', size=10)
    return p


def _add_black_heading(doc, text, level, size, centered=False):
    """Add a heading whose runs are black Times New Roman at ``size``."""
    heading = doc.add_heading(text, level=level)
    if centered:
        heading.alignment = WD_ALIGN_PARAGRAPH.CENTER
    for run in heading.runs:
        run.font.size = size
        run.font.name = 'Times New Roman'
        run.font.color.rgb = RGBColor(0, 0, 0)
    return heading


def _read_display_math(lines, start):
    """Collect the ``$$`` block opening at ``lines[start]``.

    Returns ``(raw_math_text, index_of_first_line_after_the_block)``.
    """
    opener = lines[start].strip()
    chunks = [opener.replace('$$', '')]
    nxt = start + 1
    # "$$ ... $$" on a single line is already closed; scanning on for another
    # "$$" would swallow the following paragraphs into the equation.
    if '$$' not in opener[2:]:
        while nxt < len(lines) and '$$' not in lines[nxt]:
            chunks.append(lines[nxt].strip())
            nxt += 1
        if nxt < len(lines):
            chunks.append(lines[nxt].strip().replace('$$', ''))
            nxt += 1
    return ' '.join(chunk for chunk in chunks if chunk), nxt


def _add_paper_body(doc, lines):
    """Render the paper's markdown lines: headings, math, tables and text."""
    i = 0
    while i < len(lines):
        line = lines[i]
        stripped = line.strip()

        if line.startswith('# '):
            # Title (h1)
            _add_black_heading(doc, line[2:].strip(), 0, Pt(16), centered=True)
        elif line.startswith('## '):
            # Section heading (h2)
            _add_black_heading(doc, line[3:].strip(), 1, Pt(13))
        elif line.startswith('### '):
            # Subsection heading (h3)
            _add_black_heading(doc, line[4:].strip(), 2, Pt(11))
        elif stripped.startswith('$$'):
            # Display math ($$...$$), rendered as centered plain text
            raw_math, i = _read_display_math(lines, i)
            p = doc.add_paragraph()
            p.alignment = WD_ALIGN_PARAGRAPH.CENTER
            run = p.add_run(sanitize_math(raw_math))
            run.font.size = Pt(10)
            run.font.name = 'Cambria Math'
            run.italic = True
            continue
        elif stripped.startswith('|'):
            # Inline table: consume every consecutive line containing a pipe
            end = i
            while end < len(lines) and '|' in lines[end]:
                end += 1
            add_md_table(doc, '\n'.join(lines[i:end]))
            i = end
            continue
        elif stripped.startswith('[TABLE'):
            # Table placeholder [TABLE X], kept as grey italic text
            p = doc.add_paragraph()
            p.alignment = WD_ALIGN_PARAGRAPH.CENTER
            run = p.add_run(stripped)
            run.italic = True
            run.font.size = Pt(11)
            run.font.name = 'Times New Roman'
            run.font.color.rgb = RGBColor(128, 128, 128)
        elif stripped:
            # Regular paragraph
            add_formatted_paragraph(doc, stripped)

        i += 1


def _add_tables_appendix(doc):
    """Append the tables appendix, one page per file listed in TABLE_FILES."""
    doc.add_page_break()
    _add_black_heading(doc, 'Appendix: Tables', 0, Pt(16), centered=True)

    for label, filename in TABLE_FILES:
        filepath = TABLES_DIR / filename
        if not filepath.exists():
            # Still best-effort, but make the omission visible to the user
            print(f"Warning: table file not found, skipping: {filepath}")
            continue

        doc.add_page_break()
        _add_black_heading(doc, label, 1, Pt(13))

        for section in parse_md_file(filepath):
            if section[0] == 'table':
                _, title, md = section
                # Avoid repeating a caption identical to the page heading
                if title and title != label:
                    add_md_table(doc, md, title=title)
                else:
                    add_md_table(doc, md)
            elif section[0] == 'note':
                p = doc.add_paragraph()
                run = p.add_run(section[1])
                run.italic = True
                run.font.size = Pt(9)
                run.font.name = 'Times New Roman'


def _add_figures_appendix(doc):
    """Append the figures appendix if FIGURES_DIR holds any figure files."""
    if not FIGURES_DIR.exists():
        return
    figure_files = sorted(
        f for f in FIGURES_DIR.iterdir()
        if f.suffix.lower() in ('.png', '.jpg', '.jpeg', '.pdf')
    )
    if not figure_files:
        return

    doc.add_page_break()
    _add_black_heading(doc, 'Appendix: Figures', 0, Pt(16), centered=True)

    for fig_path in figure_files:
        if fig_path.suffix.lower() == '.pdf':
            # Skip PDF figures (not supported by python-docx)
            p = doc.add_paragraph()
            run = p.add_run(f"[Figure: {fig_path.name} — PDF not embedded]")
            run.italic = True
            run.font.size = Pt(10)
            continue

        # Figure title derived from the filename
        p = doc.add_paragraph()
        p.alignment = WD_ALIGN_PARAGRAPH.CENTER
        run = p.add_run(fig_path.stem.replace('_', ' ').title())
        run.bold = True
        run.font.size = Pt(11)
        run.font.name = 'Times New Roman'

        # Embedding is deliberately best-effort: one unreadable image is
        # reported inside the document instead of aborting the whole build.
        try:
            p = doc.add_paragraph()
            p.alignment = WD_ALIGN_PARAGRAPH.CENTER
            run = p.add_run()
            run.add_picture(str(fig_path), width=Inches(5.5))
        except Exception as e:
            p = doc.add_paragraph()
            run = p.add_run(f"[Could not embed {fig_path.name}: {e}]")
            run.italic = True
            run.font.size = Pt(9)

        doc.add_paragraph()  # spacing


def build_docx():
    """Build the Word document and save it to OUTPUT.

    The document consists of the paper body rendered from PAPER_MD, a tables
    appendix built from TABLE_FILES, and a figures appendix built from the
    image files in FIGURES_DIR. The output path and file size are printed
    once the file has been saved.
    """
    doc = Document()

    # Set 1-inch margins
    for section in doc.sections:
        section.top_margin = Cm(2.54)
        section.bottom_margin = Cm(2.54)
        section.left_margin = Cm(2.54)
        section.right_margin = Cm(2.54)

    # Set default style to 11pt
    normal = doc.styles['Normal']
    normal.font.name = 'Times New Roman'
    normal.font.size = Pt(11)
    normal.paragraph_format.space_after = Pt(6)
    normal.paragraph_format.line_spacing = 1.15

    # Explicit encoding: do not depend on the platform's default codec
    paper_lines = PAPER_MD.read_text(encoding='utf-8').split('\n')
    _add_paper_body(doc, paper_lines)
    _add_tables_appendix(doc)
    _add_figures_appendix(doc)

    # Save
    doc.save(str(OUTPUT))
    print(f"Saved: {OUTPUT}")
    print(f"Size: {OUTPUT.stat().st_size / 1024:.0f} KB")


# Build the document only when this file is run as a script, not on import.
if __name__ == '__main__':
    build_docx()
